# -*- coding: utf-8 -*-
"""Code_Submission.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1zhVPBOlx3xxl83_3h4blgs42zeCXKBol
"""

"""
Matrix Product Approximation Experiments
----------------------------------------
This module implements various algorithms for approximating matrix products
and compares their performance against theoretical bounds.

Author: Your Name
License: MIT
"""

import numpy as np
import matplotlib.pyplot as plt
import os
import time
import scipy
import scipy.linalg
import cvxpy as cp
import pandas as pd
from tqdm import tqdm
import warnings
from typing import Dict, Any, List, Optional, Tuple
import itertools
import math
import traceback

# Make sure the output folders used for figures and result tables exist.
# exist_ok=True keeps repeated runs from failing when they are already there.
for _output_dir in ("plots", "results"):
    os.makedirs(_output_dir, exist_ok=True)

# =============================================================================
# VISUALIZATION SETTINGS
# =============================================================================

# Per-series plot styling, keyed by the series names used in the results
# dictionaries (optimum, proposed bounds, standard bounds, algorithms).
# The values look like matplotlib line keyword arguments; how they are
# consumed is outside this section (presumably passed to plotting calls).
# zorder runs from 10 down to 0 in declaration order.
PLOT_STYLES = {
    # --- Optimal error v_k^* ---
    'Optimal Error v_k^*': dict(
        color='gold', marker='*', linestyle='-', label=r'Optimal $v_k^*$',
        lw=4.0, markersize=16, zorder=10, markeredgewidth=1.5, markeredgecolor='black',
    ),
    # --- Bounds proposed in this work ---
    'Your Bound (QP CVXPY Best)': dict(
        color='black', marker='o', linestyle='-', label='Bound (QP Best)',
        lw=3.5, markersize=12, zorder=9, markeredgewidth=1.0,
    ),
    'Your Bound (QP Analytical)': dict(
        color='dimgrey', marker='^', linestyle=':', label='Bound (QP Approx)',
        lw=3.0, markersize=12, zorder=8, markeredgewidth=1.0,
    ),
    'Your Bound (Binary)': dict(
        color='darkgrey', marker='s', linestyle='--', label='Bound (Binary)',
        lw=3.0, markersize=12, zorder=7, markeredgewidth=1.0,
    ),
    # --- Standard theoretical bounds ---
    'Bound (Leverage Score Exp.)': dict(
        color='deepskyblue', marker='D', linestyle='-.', label='Bound (Lev. Score Exp.)',
        lw=3.0, markersize=12, zorder=6, markeredgewidth=1.0, markeredgecolor='navy',
    ),
    'Bound (Sketching Simple)': dict(
        color='sandybrown', marker='P', linestyle=':', label='Bound (Sketching Simple)',
        lw=3.0, markersize=12, zorder=5, markeredgewidth=1.0, markeredgecolor='saddlebrown',
    ),
    # --- Measured algorithm errors ---
    'Error Leverage Score (Actual)': dict(
        color='blue', marker='x', linestyle='-', label='Leverage Score Sampling',
        lw=3.0, markersize=14, alpha=0.9, zorder=4, markeredgewidth=2.0,
    ),
    'Error CountSketch (Actual)': dict(
        color='orange', marker='d', linestyle='--', label='CountSketch',
        lw=3.0, markersize=14, alpha=0.9, zorder=3, markeredgewidth=1.5, markeredgecolor='darkorange',
    ),
    'Error SRHT (Actual)': dict(
        color='red', marker='v', linestyle='-.', label='SRHT',
        lw=3.0, markersize=14, alpha=0.9, zorder=2, markeredgewidth=1.5, markeredgecolor='darkred',
    ),
    'Error Gaussian (Actual)': dict(
        color='darkviolet', marker='<', linestyle=':', label='Gaussian Proj.',
        lw=3.0, markersize=14, alpha=0.9, zorder=1, markeredgewidth=1.5, markeredgecolor='indigo',
    ),
    'Error Greedy OMP (Actual)': dict(
        color='forestgreen', marker='>', linestyle='-', label='Greedy OMP',
        lw=3.0, markersize=14, alpha=0.9, zorder=0, markeredgewidth=1.5, markeredgecolor='darkgreen',
    ),
}

# Publication-oriented matplotlib defaults, collected in one mapping and
# applied with a single update of the global rcParams.
_PUBLICATION_RC = {
    # Font sizes: base text, subplot titles, axis labels, ticks, legend, suptitle
    'font.size': 20,
    'axes.titlesize': 24,
    'axes.labelsize': 22,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 20,
    'figure.titlesize': 26,
    # Figure geometry and resolution (on-screen vs. saved files)
    'figure.figsize': (18, 8),
    'figure.dpi': 150,
    'savefig.dpi': 300,
    # Stroke widths and marker size
    'lines.linewidth': 3,
    'lines.markersize': 12,
    'axes.linewidth': 1.5,
    'grid.linewidth': 1.0,
    # Grid shown by default, lightly drawn
    'axes.grid': True,
    'grid.alpha': 0.3,
    # Bold titles and axis labels
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    'figure.titleweight': 'bold',
    # Math text rendering
    'mathtext.default': 'regular',
    'mathtext.fontset': 'cm',
}
plt.rcParams.update(_PUBLICATION_RC)

# --- Basic Helper Functions ---

def frob_norm_sq(M: np.ndarray) -> float:
    """Return the squared Frobenius norm of ``M``.

    The input is converted to a float64 array first; the norm is taken with
    ``np.linalg.norm(..., 'fro')`` and then squared.
    """
    as_double = np.asarray(M).astype(np.float64)
    frobenius = np.linalg.norm(as_double, 'fro')
    return frobenius ** 2

def col_norms_sq(M: np.ndarray) -> np.ndarray:
    """Return the squared Euclidean norm of every column of ``M``.

    The input is converted to float64 before the per-column norms
    (``axis=0``) are computed and squared.
    """
    as_double = np.asarray(M).astype(np.float64)
    per_column = np.linalg.norm(as_double, axis=0)
    return per_column ** 2

def calculate_rho_g(A: np.ndarray, B: np.ndarray) -> float:
    """Compute Rho_G = trace(G) / sum(G) for G = (A^T A) * (B^T B) (element-wise).

    Args:
        A: 2-D array with n columns.
        B: 2-D array with the same number of columns as ``A``.

    Returns:
        The ratio clipped below at 0. Special cases: 0.0 when there are no
        columns or when G is numerically zero; ``np.inf`` when the sum of G
        is at most 1e-12 but G itself is not negligible; ``np.nan`` (after a
        warning and a printed traceback) if anything raises, including a
        column-count mismatch.
    """
    try:
        left = np.asarray(A, dtype=np.float64)
        right = np.asarray(B, dtype=np.float64)
        _, n_left = left.shape
        _, n_right = right.shape
        if n_left != n_right:
            raise ValueError("Dimension mismatch in calculate_rho_g (A vs B)")
        if n_left == 0:
            return 0.0

        # Hadamard product of the two Gram matrices
        gram = (left.T @ left) * (right.T @ right)
        diag_total = np.trace(gram)
        full_total = np.sum(gram)

        if full_total <= 1e-12:
            # Denominator vanishes: either G is essentially zero (ratio 0)
            # or the ratio is unbounded.
            gram_is_zero = np.linalg.norm(gram, 'fro') < 1e-12
            return 0.0 if gram_is_zero else np.inf

        return max(0, diag_total / full_total)
    except Exception as e:
        warnings.warn(f"Error calculating Rho_G: {e}")
        traceback.print_exc()
        return np.nan

# --- Matrix Generation ---

def generate_matrices(m: int, p: int, n: int,
                      cancellation_pairs: int = 0,
                      noise_level: float = 0.0,
                      seed: Optional[int] = None,
                      distribution: str = 'gaussian') -> Tuple[np.ndarray, np.ndarray]:
    """Generate random matrices A (m x n) and B (p x n).

    Args:
        m: Number of rows of A.
        p: Number of rows of B.
        n: Shared number of columns.
        cancellation_pairs: Number of column pairs (i, j) to rewrite so that
            A[:, j] == A[:, i] and B[:, j] == -B[:, i]. Requests larger than
            ``n // 2`` are clamped to ``n // 2`` (previously such requests
            were silently ignored altogether).
        noise_level: If positive, zero-mean Gaussian noise is added to both
            matrices, with standard deviation ``noise_level * std(matrix)``
            (or ``noise_level`` itself when the matrix std is below 1e-9).
        seed: If given, seeds NumPy's *global* random state before sampling.
        distribution: ``'gaussian'`` (standard normal) or ``'uniform'``
            (uniform on [-1, 1)).

    Returns:
        The tuple ``(A, B)``.

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    """
    if seed is not None:
        np.random.seed(seed)

    if distribution == 'gaussian':
        A = np.random.randn(m, n)
        B = np.random.randn(p, n)
    elif distribution == 'uniform':
        A = np.random.rand(m, n) * 2 - 1
        B = np.random.rand(p, n) * 2 - 1
    else:
        raise ValueError("Unsupported distribution")

    # Clamp first so an oversized request yields as many disjoint pairs as
    # fit; for requests that already fit, the random stream is unchanged.
    num_pairs = min(cancellation_pairs, n // 2)
    if num_pairs > 0:
        picked = np.random.choice(n, 2 * num_pairs, replace=False)
        for first, second in zip(picked[0::2], picked[1::2]):
            # Same column in A, opposite sign in B
            A[:, second] = A[:, first]
            B[:, second] = -B[:, first]
            # Common random rescaling keeps the pair matched while varying magnitudes
            scale_factor = np.random.uniform(0.5, 1.5)
            A[:, first] *= scale_factor
            A[:, second] *= scale_factor
            B[:, first] *= scale_factor
            B[:, second] *= scale_factor

    if noise_level > 0:
        # A is perturbed before B, preserving the original draw order.
        for matrix in (A, B):
            spread = np.std(matrix)
            sigma = noise_level * spread if spread > 1e-9 else noise_level
            matrix += np.random.normal(0, sigma, size=matrix.shape)

    return A, B


def generate_matrices_for_rho(target_rho: float, m: int, p: int, n: int,
                              tolerance: float = 0.25,
                              max_attempts: int = 2500,
                              base_seed: int = 0) -> Optional[Tuple[np.ndarray, np.ndarray, float]]:
    """Search by random restarts for matrices whose Rho_G is close to ``target_rho``.

    Each attempt reseeds NumPy's global random state with
    ``(base_seed + attempt) % 2**32``, draws a cancellation-pair count and a
    noise level from ranges that depend on the target, generates the
    matrices and measures their Rho_G.

    Args:
        target_rho: Desired Rho_G value.
        m, p, n: Shapes of the generated A (m x n) and B (p x n).
        tolerance: Relative tolerance when ``target_rho > 1e-6``, absolute
            tolerance otherwise.
        max_attempts: Maximum number of seeds to try.
        base_seed: Seed used for the first attempt.

    Returns:
        ``(A, B, rho)`` for the first accepted attempt, or ``None`` (after a
        warning) when every attempt misses the tolerance.
    """
    print(f"Attempting to generate matrices for target Rho_G ≈ {target_rho:.2f} (n={n})...")
    for attempt in range(max_attempts):
        # Wrap into the seed range accepted by np.random.seed
        np.random.seed((base_seed + attempt) % (2**32))

        # Heuristic: larger targets get more cancelling pairs and more noise.
        if target_rho <= 1.0:
            no_cancellation = np.random.rand() < 0.7
            canc_pairs = 0 if no_cancellation else np.random.randint(0, max(1, n // 20))
            noise = np.random.uniform(0, 0.05)
        elif target_rho <= 10.0:
            canc_pairs = np.random.randint(max(1, n // 10), max(2, n // 4))
            noise = np.random.uniform(0.01, 0.1)
        else:
            canc_pairs = np.random.randint(max(1, n // 4), max(2, n // 2 + 1))
            noise = np.random.uniform(0.05, 0.15)

        # seed=None: keep drawing from the state seeded at the top of the loop
        A_gen, B_gen = generate_matrices(m, p, n,
                                         cancellation_pairs=canc_pairs,
                                         noise_level=noise,
                                         seed=None)
        current_rho = calculate_rho_g(A_gen, B_gen)
        if not np.isfinite(current_rho):
            continue

        deviation = abs(current_rho - target_rho)
        if target_rho > 1e-6:
            # Relative deviation for a non-zero target
            deviation = deviation / target_rho

        if deviation < tolerance:
            print(f"  Success! Found Rho_G = {current_rho:.3f} (Target ≈ {target_rho:.2f}) after {attempt + 1} attempts.")
            return A_gen, B_gen, current_rho

    warnings.warn(f"Failed to generate matrices for target Rho_G ≈ {target_rho:.2f} within {max_attempts} attempts.")
    return None


# --- USER'S BOUND FUNCTION ---
def _qp_family_ratios(data, k, n, frob_norm_AB):
    """Return (binary_ratio, qp_ratio, vk_ratio) for the QP-related bounds; NaN where unavailable."""
    binary_ratio = qp_ratio = vk_ratio = np.nan

    if any(key not in data for key in ('Gab', 'q', 'r', 'trace')):
        return binary_ratio, qp_ratio, vk_ratio

    G = data['Gab']
    q_vec = data['q']
    TrG = data['trace']
    oneGone = frob_norm_AB**2  # treated as ||AB^T||_F^2

    if not (n > 1 and oneGone > 1e-12):
        # Degenerate problem size or (near-)zero product: leave every ratio as NaN
        return binary_ratio, qp_ratio, vk_ratio

    # Analytical QP ratio
    rho_G = TrG / oneGone
    beta_k = (k - 1) / (n - 1)
    denominator = beta_k + (1 - beta_k) * rho_G
    if abs(denominator) > 1e-12:
        gamma = 1.0 / denominator
        qp_ratio = np.sqrt(max(0, 1.0 - k * gamma / n))

    # Binary ratio
    alpha_k = k / (n - 1)
    binary_bound_sq = max(0, (1.0 - k * 1.0 / n) * ((1.0 - alpha_k) * oneGone + alpha_k * TrG))
    binary_ratio = np.sqrt(binary_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0

    # Numerically solved QP ("best") ratio
    if k > 0:
        G_hat_k = beta_k * G + (1 - beta_k) * np.diag(np.diag(G))
        y = cp.Variable(n)
        prob = cp.Problem(
            cp.Minimize(0.5 * cp.quad_form(y, G_hat_k) - q_vec.T @ y),
            [y >= 0],
        )
        try:
            prob.solve(solver=cp.SCS, verbose=False, eps=1e-7, max_iters=10000)
            if prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                v_k_bound_sq = max(0, oneGone + (k / n) * 2.0 * prob.value)
                vk_ratio = np.sqrt(v_k_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0
            else:
                vk_ratio = np.nan
        except Exception:
            # Best effort: any solver failure just leaves this bound undefined
            vk_ratio = np.nan
    elif k == 0:
        # Nothing selected: full relative error (oneGone > 1e-12 holds in this branch)
        vk_ratio = 1.0

    return binary_ratio, qp_ratio, vk_ratio


def _greedy_worst_case_ratio(data, k, n, frob_norm_AB):
    """Return the worst-case greedy bound ratio from data['A'] / data['B']; NaN when inapplicable."""
    if 'A' not in data or 'B' not in data:
        return np.nan

    A_data = data['A']
    B_orig = data['B']  # here B has one ROW per term (n x p)
    _, n_A = A_data.shape
    n_B, _ = B_orig.shape
    if not (n_A == n and n_B == n):
        return np.nan

    if not 0 <= k <= n:
        return np.nan
    if k == 0:
        return 1.0 if frob_norm_AB > 1e-12 else 0.0
    if k == n:
        return 0.0

    # T_i = ||A[:, i]||^2 * ||B[i, :]||^2; the bound sums the n - k smallest T_i
    T = np.sum(A_data * A_data, axis=0) * np.sum(B_orig * B_orig, axis=1)
    smallest = np.argsort(T)[:n - k]
    bound_val = np.sqrt(max(0, np.sum(T[smallest])))
    if frob_norm_AB > 1e-12:
        return bound_val / frob_norm_AB
    return 0.0 if bound_val < 1e-12 else np.inf


def compute_theoretical_bounds(data: Dict[str, Any], k: int) -> Dict[str, Any]:
    """
    Compute theoretical error-bound ratios for sparsity level ``k``.

    Covers the QP-related bounds and the worst-case bound for Algorithm 1
    from Belabbas & Wolfe (2009).

    Args:
        data: Must contain 'n' and 'frob_norm' (used as ||AB^T||_F). The QP
            bounds additionally need 'Gab', 'q', 'r' and 'trace'; the greedy
            bound needs 'A' (m x n) and 'B' (n x p).
        k: Number of selected terms.

    Returns:
        Dict with keys 'binary_ratio', 'qp_ratio', 'qp_ratio_best' and
        'Greedy_ratio'. A value is NaN when it could not be computed.
    """
    n = data['n']
    frob_norm_AB = data['frob_norm']

    binary_ratio, qp_ratio, vk_ratio = _qp_family_ratios(data, k, n, frob_norm_AB)
    Greedy_bound_ratio = _greedy_worst_case_ratio(data, k, n, frob_norm_AB)

    return {
        'binary_ratio': binary_ratio,
        'qp_ratio': qp_ratio,
        'qp_ratio_best': vk_ratio,
        'Greedy_ratio': Greedy_bound_ratio
    }

# --- STANDARD BOUNDS FUNCTION ---
def compute_standard_bounds(A: np.ndarray, B: np.ndarray, k: int, frob_ABt_sq: float) -> Dict[str, float]:
    """
    Compute two standard AMM/sketching bounds as squared-error ratios.

    Both values are expressed relative to ``frob_ABt_sq`` (||AB^T||_F^2).

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of samples / sketch size.
        frob_ABt_sq: Squared Frobenius norm of the exact product.

    Returns:
        Dict with 'Bound (Leverage Score Exp.)' and 'Bound (Sketching Simple)'.
        For k == 0 with a non-negligible product the values are 1.0 and inf.
        Both are NaN for mismatched or empty inputs or a (near-)zero product;
        either is NaN if its computation raised (a RuntimeWarning is issued).
    """
    lev_key = 'Bound (Leverage Score Exp.)'
    sketch_key = 'Bound (Sketching Simple)'

    _, n = A.shape
    _, n2 = B.shape

    unusable = n != n2 or n == 0 or frob_ABt_sq < 1e-20
    if unusable:
        if k == 0 and frob_ABt_sq > 1e-20:
            # Nothing sampled: full error, and the 1/k sketching bound is unbounded
            return {lev_key: 1.0, sketch_key: np.inf}
        return {lev_key: np.nan, sketch_key: np.nan}
    if k == 0:
        return {lev_key: 1.0, sketch_key: np.inf}

    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)

    try:
        # ((sum_i ||A_i|| ||B_i||)^2 - ||AB^T||_F^2) / k, clipped at zero
        column_products = np.linalg.norm(A_f64, axis=0) * np.linalg.norm(B_f64, axis=0)
        sum_prod_norms = max(0, np.sum(column_products))
        expected_error_sq = (sum_prod_norms**2 - frob_ABt_sq) / k
        leverage_ratio = max(0, expected_error_sq) / frob_ABt_sq
    except Exception as e:
        warnings.warn(f"Failed to compute Leverage Score bound for k={k}: {e}", RuntimeWarning)
        leverage_ratio = np.nan

    try:
        # ||A||_F^2 ||B||_F^2 / k, relative to ||AB^T||_F^2
        sketching_bound_sq = (frob_norm_sq(A_f64) * frob_norm_sq(B_f64)) / k
        sketching_ratio = max(0, sketching_bound_sq / frob_ABt_sq)
    except Exception as e:
        warnings.warn(f"Failed to compute Simple Sketching bound for k={k}: {e}", RuntimeWarning)
        sketching_ratio = np.nan

    return {lev_key: leverage_ratio, sketch_key: sketching_ratio}


# --- Algorithm Implementations ---

def run_leverage_score_sampling(A: np.ndarray, B: np.ndarray, k: int, optimal: bool = True, replacement: bool = False) -> np.ndarray:
    """Approximate AB^T by sampling k column pairs (formerly AMM/RMM).

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of columns to sample; capped at n when sampling without
            replacement.
        optimal: If True, sample column i with probability proportional to
            ||A_i|| * ||B_i|| (uniform if that total is below 1e-12);
            otherwise sample uniformly.
        replacement: Whether to sample with replacement.

    Returns:
        The m x p estimate built from the sampled columns, each scaled by
        1 / sqrt(k * p_i) in both factors.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    m, n = A.shape
    p, n2 = B.shape
    if n != n2:
        raise ValueError("Matrices A and B must have the same number of columns (n)")
    if k <= 0:
        raise ValueError(f"k must be positive for Leverage Score Sampling, got {k}")
    if n == 0:
        return np.zeros((m, p), dtype=A.dtype)

    if not replacement and k > n:
        k = n

    uniform = np.ones(n) / n
    if optimal:
        scores = np.linalg.norm(A, axis=0) * np.linalg.norm(B, axis=0)
        total_score = np.sum(scores)
        probs = uniform if total_score < 1e-12 else scores / total_score
    else:
        probs = uniform

    # Keep every probability strictly positive, then renormalise
    floored = np.maximum(probs, 1e-12)
    probs = floored / floored.sum()

    picked = np.random.choice(n, size=k, replace=replacement, p=probs)
    weights = 1.0 / np.sqrt(k * probs[picked])

    return (A[:, picked] * weights) @ (B[:, picked] * weights).T

def _countsketch_columns(M: np.ndarray, buckets: np.ndarray, signs: np.ndarray, k: int) -> np.ndarray:
    """Sum the sign-flipped columns of ``M`` into ``k`` buckets; returns a (rows x k) array."""
    # Accumulate in a floating dtype: the signs are floats, so integer
    # buffers would reject the in-place additions.
    acc_dtype = M.dtype if np.issubdtype(M.dtype, np.inexact) else np.float64
    signed_cols = (M * signs).T.astype(acc_dtype, copy=False)  # one row per original column
    sketch_t = np.zeros((k, M.shape[0]), dtype=acc_dtype)
    # Unbuffered add so columns sharing a bucket are all accumulated
    np.add.at(sketch_t, buckets, signed_cols)
    return sketch_t.T


def run_countsketch(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate AB^T using CountSketch.

    Every column index j gets a random bucket h[j] in [0, k) and a random
    sign g[j] in {-1, +1}; column j of A and of B is added, with that sign,
    to column h[j] of the respective sketch. The product of the two sketches
    is returned.

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of buckets (sketch size); must be positive.

    Returns:
        The m x p sketch-based estimate of AB^T. Integer inputs are
        accumulated in float64 (they previously raised a casting error).

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    m, n = A.shape
    p, n2 = B.shape
    if n != n2:
        raise ValueError("Dimension mismatch")
    if k <= 0:
        raise ValueError("k must be positive for CountSketch")
    if n == 0:
        return np.zeros((m, p), dtype=A.dtype)

    # Same draw order as before: buckets first, then signs
    h = np.random.randint(0, k, size=n)
    g = np.random.choice([-1.0, 1.0], size=n)

    SA = _countsketch_columns(A, h, g, k)
    SB = _countsketch_columns(B, h, g, k)
    return SA @ SB.T

def run_gaussian_projection(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate AB^T using a Gaussian random projection.

    A k x n matrix with i.i.d. standard normal entries scaled by 1/sqrt(k)
    is drawn, both inputs are projected onto it, and the product of the
    projections is returned.

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Projection dimension; must be positive.

    Returns:
        The m x p estimate (A S^T)(B S^T)^T.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    rows_a, n = A.shape
    rows_b, n2 = B.shape
    if n != n2:
        raise ValueError("Dimension mismatch")
    if k <= 0:
        raise ValueError("k must be positive for Gaussian Projection")
    if n == 0:
        return np.zeros((rows_a, rows_b), dtype=A.dtype)

    projector = np.random.randn(k, n) / np.sqrt(k)  # k x n
    projected_a = A @ projector.T                   # m x k
    projected_b = B @ projector.T                   # p x k
    return projected_a @ projected_b.T

def run_greedy_selection_omp(A: np.ndarray, B: np.ndarray, k: int, ABt_exact: Optional[np.ndarray] = None) -> np.ndarray:
    """Approximate AB^T with a greedy, matching-pursuit-style column selection.

    In each of up to k rounds, the not-yet-chosen column j whose rank-one
    term A_j B_j^T has the largest absolute inner product with the current
    residual is added (the earliest such column wins ties). The residual is
    then recomputed as the exact product minus the unweighted sum of the
    chosen rank-one terms.

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of columns to select, with 1 <= k <= n.
        ABt_exact: Optional precomputed A @ B.T; computed when omitted.

    Returns:
        The m x p sum of the selected rank-one terms (zeros if nothing
        could be selected).

    Raises:
        ValueError: If the column counts differ or k is outside [1, n].
    """
    m, n = A.shape
    p, n2 = B.shape
    if n != n2:
        raise ValueError("Dimension mismatch")
    if not (1 <= k <= n):
        raise ValueError(f"k={k} must be between 1 and n={n} for Greedy OMP")

    if ABt_exact is None:
        ABt_exact = A @ B.T

    # Selection works in float64; the returned product uses the original dtypes
    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)

    chosen = []
    candidates = list(range(n))
    residual = ABt_exact.astype(np.float64).copy()

    for _ in range(k):
        if not candidates:
            break

        winner_pos = -1
        winner_score = -1
        for pos, col in enumerate(candidates):
            rank_one = np.outer(A_f64[:, col], B_f64[:, col])
            score = np.abs(np.sum(residual * rank_one))
            if score > winner_score:  # strict: the first maximum is kept
                winner_score = score
                winner_pos = pos

        if winner_pos == -1:
            # No candidate compared greater (e.g. all scores NaN)
            break

        chosen.append(candidates.pop(winner_pos))
        residual = ABt_exact - A_f64[:, chosen] @ B_f64[:, chosen].T

    if not chosen:
        return np.zeros((m, p), dtype=A.dtype)
    return A[:, chosen] @ B[:, chosen].T

def fast_walsh_hadamard_transform_manual(X: np.ndarray, axis: int = -1) -> np.ndarray:
    """Recursive, unnormalised fast Walsh-Hadamard transform along one axis.

    The entries at even and odd positions along ``axis`` are transformed
    recursively; the first half of the output is their sum and the second
    half their difference. The output ordering is whatever this recursion
    produces.

    Args:
        X: Array-like input; converted to a float array.
        axis: Axis to transform; its length must be a power of two.

    Returns:
        A float array with the same shape as the input (for length 1 the
        converted input itself is returned).

    Raises:
        ValueError: If the length along ``axis`` is not a power of two.
    """
    values = np.asarray(X, dtype=float)
    length = values.shape[axis]

    is_power_of_two = length > 0 and (length & (length - 1)) == 0
    if not is_power_of_two:
        raise ValueError(f"Input size along axis {axis} must be a power of 2, got {length}")

    if length == 1:
        return values

    # Normalise a negative axis so index tuples can be built positionally
    ax = values.ndim + axis if axis < 0 else axis
    leading = (slice(None),) * ax

    even_part = fast_walsh_hadamard_transform_manual(values[leading + (slice(None, None, 2),)], axis=ax)
    odd_part = fast_walsh_hadamard_transform_manual(values[leading + (slice(1, None, 2),)], axis=ax)

    return np.concatenate((even_part + odd_part, even_part - odd_part), axis=ax)

def pad_matrix(A: np.ndarray, axis: int = 1) -> Tuple[np.ndarray, int]:
    """Zero-pad ``A`` at the end of ``axis`` up to the next power of two.

    Args:
        A: Array to pad.
        axis: Axis whose length is rounded up.

    Returns:
        ``(padded, original_length)``. The input array itself is returned
        when the length is already a power of two or is zero.
    """
    n_orig = A.shape[axis]
    if n_orig == 0:
        return A, 0

    # Smallest power of two that is >= n_orig
    target = 1 << (n_orig - 1).bit_length()
    missing = target - n_orig
    if missing <= 0:
        return A, n_orig

    widths = [(0, 0)] * A.ndim
    widths[axis] = (0, missing)  # pad only after the existing entries
    return np.pad(A, pad_width=widths, mode='constant', constant_values=0), n_orig

def run_srht_new(A: np.ndarray, B: np.ndarray, k: int, optimal_sampling: bool = False) -> np.ndarray:
    """Approximate AB^T with a subsampled randomised Hadamard transform (SRHT).

    A and B are stacked row-wise and zero-padded to a power-of-two column
    count; the columns get random signs, the rows are transformed with the
    manual FWHT (normalised by sqrt(N)), and k transformed columns are
    sampled uniformly without replacement and rescaled by sqrt(N / k).

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of transformed columns to keep; reduced (with a
            RuntimeWarning) if it exceeds the padded column count.
        optimal_sampling: Accepted but not used by this implementation.

    Returns:
        The m x p estimate of AB^T.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
        RuntimeError: If the Hadamard transform step fails.
    """
    m, n = A.shape
    p, n2 = B.shape
    if n != n2:
        raise ValueError("Dimension mismatch")
    if k <= 0:
        raise ValueError(f"k must be positive for SRHT, got {k}")

    zero_result_shape = (m, p)
    if n == 0:
        return np.zeros(zero_result_shape, dtype=A.dtype)

    # Transform A and B together so both see the same signs and samples
    padded, _ = pad_matrix(np.vstack((A, B)), axis=1)
    N_padded = padded.shape[1]
    if N_padded == 0:
        return np.zeros(zero_result_shape, dtype=A.dtype)

    k_actual = min(k, N_padded)
    if k_actual < k:
        warnings.warn(f"SRHT sampling k reduced from {k} to {k_actual} due to padded dim N={N_padded}", RuntimeWarning)
    if k_actual == 0:
        return np.zeros(zero_result_shape, dtype=A.dtype)

    # Random sign flip per (padded) column
    AB_signed = padded * np.random.choice([-1.0, 1.0], size=N_padded)

    try:
        transformed = fast_walsh_hadamard_transform_manual(AB_signed, axis=1) / np.sqrt(N_padded)
    except ValueError as ve:
        raise RuntimeError(f"Manual FWHT failed: {ve}. Input shape: {AB_signed.shape}")
    except Exception as e:
        raise RuntimeError(f"Unexpected error in Manual FWHT: {e}\n{traceback.format_exc()}")

    kept = np.random.choice(N_padded, size=k_actual, replace=False)
    sampled = transformed[:, kept] * np.sqrt(N_padded / k_actual)

    # Undo the stacking: the first m rows belong to A, the rest to B
    return sampled[:m, :] @ sampled[m:, :].T


# --- Optimal v_k* Calculation ---
def compute_optimal_vk_star(A: np.ndarray, B: np.ndarray, k: int,
                            ABt_exact: np.ndarray,
                            threshold: int = 100000000):
    """
    Brute-force the optimal k-term squared error for approximating AB^T.

    Minimises ||AB^T - sum_{i in S} w_i A_i B_i^T||_F^2 over every size-k
    index set S, using least-squares optimal weights for each set.

    Args:
        A: m x n matrix.
        B: p x n matrix.
        k: Number of rank-one terms allowed.
        ABt_exact: The exact product A @ B.T.
        threshold: If the number of k-subsets exceeds this, the search is
            skipped.

    Returns:
        The minimal squared error. Special cases: ||AB^T||_F^2 for k == 0,
        0.0 for k == n, and NaN when k is out of range (with a
        RuntimeWarning) or the search was skipped.
    """
    _, n_cols = A.shape

    # Work in float64 throughout
    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)
    target = ABt_exact.astype(np.float64)
    total_energy = frob_norm_sq(target)

    if k < 0 or k > n_cols:
        warnings.warn(f"k={k} is out of bounds for n_cols={n_cols}.", RuntimeWarning)
        return np.nan
    if k == 0:
        return total_energy
    if k == n_cols:
        # Using every term reproduces AB^T exactly
        return 0.0

    try:
        subset_count = math.comb(n_cols, k)
    except ValueError:
        return np.nan
    if subset_count > threshold:
        return np.nan  # search skipped: too many subsets

    # gram[i, j] = <A_i B_i^T, A_j B_j^T> = (A_i^T A_j) * (B_i^T B_j)
    gram = (A_f64.T @ A_f64) * (B_f64.T @ B_f64)
    # rhs[i] = <AB^T, A_i B_i^T> = A_i^T (AB^T) B_i
    rhs = np.array(
        [A_f64[:, i] @ target @ B_f64[:, i] for i in range(n_cols)],
        dtype=np.float64,
    )

    best_error_sq = total_energy
    for subset in itertools.combinations(range(n_cols), k):
        indices = list(subset)
        sub_gram = gram[np.ix_(indices, indices)]
        sub_rhs = rhs[indices]

        try:
            # lstsq tolerates singular Gram blocks
            weights = np.linalg.lstsq(sub_gram, sub_rhs, rcond=None)[0]
            # Residual energy of the projection, clamped against round-off
            subset_error_sq = max(0, total_energy - np.dot(sub_rhs, weights))
        except np.linalg.LinAlgError:
            warnings.warn(f"Unexpected LinAlgError with lstsq for k={k}, indices={indices}. Skipping combination.", RuntimeWarning)
            continue

        if subset_error_sq < best_error_sq:
            best_error_sq = subset_error_sq

    return best_error_sq


# --- Experiment Runner ---

class ExperimentFailureError(Exception):
    """Signals a failure inside run_experiment_iterations."""

def run_experiment_iterations(A_matrix: np.ndarray, B_matrix: np.ndarray, k_values_list: List[int],
                              num_trials: int = 100,
                              compute_vk_star_flag: bool = False, vk_star_comb_threshold: int = 100000,
                              run_algorithms_flag: bool = True, run_bounds_flag: bool = True) -> Dict[str, Any]:
    """
    Run the approximation algorithms and evaluate the bounds for each k.

    All per-k series in the result are *relative squared* quantities, i.e.
    divided by ``||A B^T||_F^2`` (bounds returned as ratios are squared).
    "Your Bound (Binary)" is post-processed into a running minimum so that it
    is non-increasing in k. Progress is reported via ``print`` and ``tqdm``.
    (Formerly run_experiments_flexible)

    Args:
        A_matrix: 2-D array of shape (m, n).
        B_matrix: 2-D array of shape (p, n); the exact product is ``A @ B.T``.
        k_values_list: Requested sample counts. Values outside ``[0, n]`` are
            dropped; the remainder is de-duplicated and sorted.
        num_trials: Number of repetitions averaged for each randomized
            algorithm (leverage score, CountSketch, SRHT, Gaussian).
        compute_vk_star_flag: Whether to compute the optimal error v_k^*.
        vk_star_comb_threshold: Forwarded to ``compute_optimal_vk_star``.
        run_algorithms_flag: Whether to run the randomized algorithms and
            greedy OMP.
        run_bounds_flag: Whether to evaluate the theoretical bounds.

    Returns:
        Dict with keys ``'k'`` (filtered k values), ``'n'``, one float array
        per bound/algorithm (NaN where not computed, finite values clipped to
        ``[0, 1e6]``), plus scalars ``'Frob_ABt_Sq'`` and ``'Rho_G'``. When no
        valid k remains or every computation is disabled, the NaN-initialised
        dict is returned after a warning.

    Raises:
        ExperimentFailureError: If setup, a bound computation, v_k^*, a
            randomized trial, or greedy OMP raises. The original exception is
            attached as ``__cause__``.
    """
    A = A_matrix
    B = B_matrix  # B is p x n
    _, n = A.shape  # unpacking also rejects non-2-D input

    requested_k = np.array(k_values_list, dtype=int)
    k_values = np.unique(requested_k[(requested_k >= 0) & (requested_k <= n)])
    num_k = len(k_values)

    randomized_error_keys = (
        'Error Leverage Score (Actual)',
        'Error CountSketch (Actual)',
        'Error SRHT (Actual)',
        'Error Gaussian (Actual)',
    )
    omp_key = 'Error Greedy OMP (Actual)'
    series_keys = (
        'Your Bound (Binary)',
        'Your Bound (QP Analytical)',
        'Your Bound (QP CVXPY Best)',
        'Bound (Leverage Score Exp.)',
        'Bound (Sketching Simple)',
    ) + randomized_error_keys + (omp_key, 'Optimal Error v_k^*')

    results = {'k': k_values, 'n': n}
    results.update({key: np.full(num_k, np.nan) for key in series_keys})
    results['Frob_ABt_Sq'] = np.nan
    results['Rho_G'] = np.nan

    if num_k == 0:
        warnings.warn("No valid k values.")
        return results
    if not any([run_algorithms_flag, run_bounds_flag, compute_vk_star_flag]):
        warnings.warn("All computations disabled.")
        return results

    try:
        A_f64 = A.astype(np.float64)
        B_f64 = B.astype(np.float64)
        ABt_exact = A_f64 @ B_f64.T
        frob_ABt_sq = frob_norm_sq(ABt_exact)
        frob_ABt = np.sqrt(frob_ABt_sq) if frob_ABt_sq > 1e-20 else 0.0
        results['Frob_ABt_Sq'] = frob_ABt_sq
        results['Rho_G'] = calculate_rho_g(A_f64, B_f64)

        if frob_ABt_sq < 1e-20:
            # Degenerate product: relative errors are undefined, so report
            # zero everywhere except the simple sketching bound (infinite).
            warnings.warn(f"||AB^T||_F^2 is near zero ({frob_ABt_sq:.2e}).")
            for key in series_keys:
                results[key].fill(np.inf if key == 'Bound (Sketching Simple)' else 0.0)
            return results

        if n > 0:
            AtA = A_f64.T @ A_f64
            BtB = B_f64.T @ B_f64
            G = AtA * BtB  # element-wise product of the two Gram matrices
            trace_G = np.trace(G)
            sum_G = np.sum(G)
            q_vec = G @ np.ones(n)
        else:
            G, trace_G, sum_G, q_vec = np.array([[]]), 0.0, 0.0, np.array([])

        bound_data_dict = {
            'n': n, 'frob_norm': frob_ABt, 'A': A_f64, 'B': B_f64.T,  # B_f64.T is n x p
            'Gab': G, 'q': q_vec, 'r': sum_G, 'trace': trace_G
        }
    except Exception as e:
        raise ExperimentFailureError(f"Initial setup (n={n}): {e}\n{traceback.format_exc()}") from e

    if compute_vk_star_flag:
        print(f"Computing Optimal v_k* (threshold={vk_star_comb_threshold}, n={n})...")
        vk_series = results['Optimal Error v_k^*']
        last_abs_vk_star_error = frob_ABt_sq

        for i, k_val_vk in enumerate(tqdm(k_values, desc="Optimal v_k*", leave=False)):
            try:
                abs_error_sq = compute_optimal_vk_star(A_f64, B_f64, k_val_vk, ABt_exact, vk_star_comb_threshold)
                if np.isnan(abs_error_sq):
                    # No value available for this k: carry the previous
                    # entry forward (or 1.0 for the first entry).
                    vk_series[i] = vk_series[i - 1] if i > 0 else 1.0
                    continue
                if k_val_vk > 0 and abs_error_sq > last_abs_vk_star_error + 1e-9 * frob_ABt_sq:
                    warnings.warn(f"Optimal v_k* (absolute error) potentially non-monotonic at k={k_val_vk}: {abs_error_sq:.4e} > prev k's {last_abs_vk_star_error:.4e}.", RuntimeWarning)
                vk_series[i] = abs_error_sq / frob_ABt_sq
                last_abs_vk_star_error = abs_error_sq
            except Exception as e:
                raise ExperimentFailureError(f"v_k^* for k={k_val_vk}: {e}\n{traceback.format_exc()}") from e
        results['Optimal Error v_k^*'] = np.maximum(0, vk_series)

    # (result key, key in the dict returned by compute_theoretical_bounds)
    user_bound_keys = (
        ('Your Bound (Binary)', 'binary_ratio'),
        ('Your Bound (QP Analytical)', 'qp_ratio'),
        ('Your Bound (QP CVXPY Best)', 'qp_ratio_best'),
    )
    standard_bound_keys = ('Bound (Leverage Score Exp.)', 'Bound (Sketching Simple)')
    # (result key, tag used in warning messages, sampler). Built only when the
    # algorithms are requested; the order is the per-trial execution order.
    samplers = ()
    if run_algorithms_flag:
        samplers = (
            ('Error Leverage Score (Actual)', 'LevScore', run_leverage_score_sampling),
            ('Error CountSketch (Actual)', 'CS', run_countsketch),
            ('Error SRHT (Actual)', 'SRHT', run_srht_new),
            ('Error Gaussian (Actual)', 'Gauss', run_gaussian_projection),
        )

    print(f"Computing Bounds and Algorithm Errors (num_trials={num_trials})...")
    for i, k_val_main in enumerate(tqdm(k_values, desc=f"Exp (n={n}, RhoG={results['Rho_G']:.2f})", leave=False)):
        if run_bounds_flag:
            try:
                user_bounds = compute_theoretical_bounds(bound_data_dict, k_val_main)
                for result_key, ratio_key in user_bound_keys:
                    ratio = user_bounds[ratio_key]
                    results[result_key][i] = ratio ** 2 if not np.isnan(ratio) else np.nan
            except Exception as e:
                raise ExperimentFailureError(f"compute_theoretical_bounds k={k_val_main}: {e}\n{traceback.format_exc()}") from e
            try:
                std_bounds = compute_standard_bounds(A_f64, B_f64, k_val_main, frob_ABt_sq)
                for result_key in standard_bound_keys:
                    results[result_key][i] = std_bounds[result_key]
            except Exception as e:
                raise ExperimentFailureError(f"compute_standard_bounds k={k_val_main}: {e}\n{traceback.format_exc()}") from e

        if not run_algorithms_flag:
            continue

        if k_val_main == 0:
            # With zero samples the approximation is the zero matrix, so the
            # relative squared error is exactly 1 for every algorithm.
            for key in randomized_error_keys + (omp_key,):
                results[key][i] = 1.0
            continue

        trial_errors = {key: [] for key in randomized_error_keys}
        trial_failure = None
        for _ in range(num_trials):
            for result_key, tag, sampler in samplers:
                try:
                    approx = sampler(A_f64, B_f64, k_val_main)
                    trial_errors[result_key].append(frob_norm_sq(ABt_exact - approx))
                except Exception as e:
                    warnings.warn(f"{tag} trial k={k_val_main}: {e}", RuntimeWarning)
                    trial_failure = e
                    break
            if trial_failure is not None:
                break

        if trial_failure is not None:
            raise ExperimentFailureError(f"Rand algo trial failed k={k_val_main}.") from trial_failure

        for result_key in randomized_error_keys:
            results[result_key][i] = np.mean(trial_errors[result_key]) / frob_ABt_sq

        # k is guaranteed to lie in [1, n] here (filtered above, k == 0 handled).
        try:
            omp_approx = run_greedy_selection_omp(A_f64, B_f64, k_val_main, ABt_exact)
            results[omp_key][i] = frob_norm_sq(ABt_exact - omp_approx) / frob_ABt_sq
        except Exception as e:
            raise ExperimentFailureError(f"Greedy OMP k={k_val_main}: {e}\n{traceback.format_exc()}") from e

    # Post-process 'Your Bound (Binary)' into a NaN-skipping running minimum so
    # it is non-increasing; entries whose running minimum is not finite
    # (nothing valid seen yet, or an infinite bound) are reported as NaN.
    running_min_sq = np.fmin.accumulate(results['Your Bound (Binary)'])
    running_min_sq[~np.isfinite(running_min_sq)] = np.nan
    results['Your Bound (Binary)'] = running_min_sq

    # Clamp every per-k series into [0, 1e6], leaving NaN placeholders alone.
    for key in series_keys:
        valid_mask = ~np.isnan(results[key])
        if np.any(valid_mask):
            results[key][valid_mask] = np.clip(results[key][valid_mask], 0, 1000000.0)

    return results


# ---  PLOTTING FUNCTION ---
def _save_figure_best_effort(fig, filepath: str, what: str = "") -> None:
    """Save *fig* to *filepath* at 300 dpi, printing the outcome instead of raising."""
    try:
        fig.savefig(filepath, bbox_inches='tight', dpi=300)
        print(f"Saved: {filepath}")
    except Exception as e:
        print(f"Error saving {what}{filepath}: {e}")


def plot_experiment_results(results_by_name: Dict[str, Dict],
                            main_plot_filename_base: str = "experiment_combined",
                            plot_styles_dict: Dict = PLOT_STYLES,
                            y_axis_label_flag: bool = False,
                            figure_super_title: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
    """
    Plot the results of the experiment with publication-quality formatting.
    (Formerly plot_combined_experiment)

    One log-scale subplot is drawn per experiment (shared y-axis), with the
    x-axis showing k/n for k > 0. Only finite values above 1e-12 are drawn.
    The legend is rendered into a separate figure. Both figures are written to
    the "plots" directory; the main figure is also shown via ``plt.show()``.

    Parameters:
        results_by_name (dict): Dictionary of results, keys are experiment names.
            Each value must provide 'k' and 'n'; entries that are empty or
            have no 'k' values are ignored.
        main_plot_filename_base (str): Base filename for saving the plot.
        plot_styles_dict (dict): Maps a result key to the keyword arguments for
            ``ax.plot``; its optional 'label' entry is used as the legend text.
        y_axis_label_flag (bool): Whether to include the y-axis label on the first subplot.
        figure_super_title (str, optional): Super title for the entire figure.
    Returns:
        tuple: (filepath_combined, filepath_legend). ``(None, None)`` when
        nothing could be plotted; the legend path is ``None`` when no legend
        entries were collected. Save errors are printed, not raised, and the
        intended path is still returned.
    """
    valid_results = {name: res for name, res in results_by_name.items()
                     if res and 'k' in res and len(res['k']) > 0}
    num_plots = len(valid_results)
    if num_plots == 0:
        print("No valid results to plot.")
        return None, None

    # Create figure with appropriate size
    fig, axes = plt.subplots(1, num_plots, figsize=(7 * num_plots, 8), sharey=True)
    if num_plots == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a single subplot

    legend_handles, legend_labels = [], []
    global_min_y, global_max_y = np.inf, -np.inf
    any_plot_successful = False

    for i, results in enumerate(valid_results.values()):
        ax = axes[i]
        k_values = np.asarray(results['k'])  # accept lists as well as arrays
        n_dim = results['n']
        actual_rho = results.get('Rho_G', np.nan)

        title_text = r"$\rho_G \approx {:.2f}$".format(actual_rho) + f" (n={n_dim})"

        # Only k > 0 is plotted
        plot_mask = k_values > 0
        if not np.any(plot_mask):
            ax.set_title(title_text + "\n(No k>0 data)", fontsize=22, fontweight='bold')
            ax.text(0.5, 0.5, 'No data for k>0', ha='center', va='center', transform=ax.transAxes, fontsize=20)
            continue

        k_plot_abs = k_values[plot_mask]
        k_plot_ratio = k_plot_abs / n_dim if n_dim > 0 else np.zeros_like(k_plot_abs, dtype=float)

        subplot_min_y, subplot_max_y = np.inf, -np.inf
        plotted_something_on_subplot = False

        # Plot each method that has a series of matching length
        for method_key, style_params in plot_styles_dict.items():
            if method_key not in results:
                continue
            y_values_all = np.asarray(results[method_key])
            if len(y_values_all) != len(k_values):
                continue

            y_values = y_values_all[plot_mask]
            # Log scale: keep only finite, strictly positive (> 1e-12) values
            valid_mask = np.isfinite(y_values) & (y_values > 1e-12)
            if not np.any(valid_mask):
                continue
            x_plot_valid = k_plot_ratio[valid_mask]
            y_plot_valid = y_values[valid_mask]

            line_style = dict(style_params)
            # The label is kept out of ax.plot; it only feeds the separate legend
            label_text = line_style.pop('label', method_key)
            line = ax.plot(x_plot_valid, y_plot_valid, **line_style)
            plotted_something_on_subplot = True

            if label_text not in legend_labels:
                legend_handles.append(line[0])
                legend_labels.append(label_text)

            subplot_min_y = min(subplot_min_y, np.min(y_plot_valid))
            subplot_max_y = max(subplot_max_y, np.max(y_plot_valid))

        if not plotted_something_on_subplot:
            ax.set_title(title_text + "\n(No valid data > 0 for k>0)", fontsize=22, fontweight='bold')
            ax.text(0.5, 0.5, 'No valid data > 0', ha='center', va='center', transform=ax.transAxes, fontsize=20)
            continue

        # Configure subplot
        any_plot_successful = True
        if np.isfinite(subplot_min_y):
            global_min_y = min(global_min_y, subplot_min_y)
        if np.isfinite(subplot_max_y):
            global_max_y = max(global_max_y, subplot_max_y)

        ax.set_title(title_text, fontsize=22, fontweight='bold', pad=15)
        ax.grid(True, linestyle=':', alpha=0.5)
        ax.set_yscale('log')
        ax.tick_params(axis='both', which='major', labelsize=20)

        # X-axis label only on the centre subplot
        if i == num_plots // 2:
            ax.set_xlabel("Proportion of Samples (k/n)", fontsize=22, fontweight='bold', labelpad=10)

        # Set x-axis limits with a 5% margin around the plotted ratios
        if len(k_plot_ratio) > 0 and n_dim > 0:
            has_positive_ratio = np.any(k_plot_ratio > 0)
            min_x_val = np.min(k_plot_ratio[k_plot_ratio > 0]) if has_positive_ratio else 0.01 / n_dim
            max_x_val = np.max(k_plot_ratio) if has_positive_ratio else 1.0
            ax.set_xlim(left=min_x_val * 0.95 if min_x_val > 0 else 0.001,
                        right=max_x_val * 1.05 if max_x_val > 0 else 1.0)
        else:
            ax.set_xlim(left=0.001, right=1.0)

    if not any_plot_successful:
        print("Skipping plot: No valid data plotted.")
        plt.close(fig)
        return None, None

    # Add figure title if provided
    if figure_super_title:
        fig.suptitle(figure_super_title, fontsize=26, y=0.98, fontweight='bold')

    # Add y-axis label if requested
    if y_axis_label_flag:
        axes[0].set_ylabel("Relative Squared Error (Log Scale)", fontsize=22, fontweight='bold', labelpad=15)

    # Set global y-axis limits (axes share y, so setting the first is enough).
    # '>=' matters: when every plotted value is identical the limits must still
    # be derived from the data, otherwise the fixed fallback window could hide
    # values above 1.5.
    if np.isfinite(global_min_y) and np.isfinite(global_max_y) and global_max_y >= global_min_y:
        y_lower = max(1e-9, global_min_y * 0.5)
        y_upper = global_max_y * 1.5
        if y_upper / y_lower < 100:  # Ensure a reasonable range for log scale
            y_upper = y_lower * 100
        axes[0].set_ylim(bottom=y_lower, top=y_upper)
    else:
        axes[0].set_ylim(bottom=1e-9, top=1.5)

    # Adjust layout (leave room for the suptitle when present) and save
    plt.tight_layout(rect=[0, 0.03, 1, 0.95 if figure_super_title else 0.98])
    plot_dir = "plots"  # created at module import time

    filepath_combined = os.path.join(plot_dir, f"{main_plot_filename_base}_plots.png")
    _save_figure_best_effort(fig, filepath_combined)

    plt.show()
    plt.close(fig)

    # Create separate legend
    if not (legend_handles and legend_labels):
        print("No handles/labels for separate legend.")
        return filepath_combined, None

    fig_legend = plt.figure(figsize=(12, 2.5))
    ax_legend = fig_legend.add_subplot(111)

    # At most three legend columns
    num_legend_cols = min(len(legend_labels), 3)

    leg = ax_legend.legend(legend_handles, legend_labels,
                           loc='center', ncol=num_legend_cols, frameon=True,
                           fontsize=20, framealpha=0.7)
    for text in leg.get_texts():
        text.set_fontweight('bold')

    ax_legend.axis('off')
    filepath_legend = os.path.join(plot_dir, f"{main_plot_filename_base}_legend.png")
    _save_figure_best_effort(fig_legend, filepath_legend, "legend ")

    plt.close(fig_legend)
    return filepath_combined, filepath_legend